# -*- coding: utf-8 -*-
import re
import numpy as np
import matplotlib.pyplot as plt
import os

# Matches "/<digits>": the texture / normal index suffixes that may follow the
# vertex index inside an OBJ face token (e.g. the "/7/2" in "12/7/2").
RE = re.compile(r"/\d+")


class ObjFile:
    def __init__(self, obj_file=None):
        """Create a mesh container, optionally loading it from an OBJ file.

        Args:
            obj_file: Path of the OBJ file to parse. When falsy (the default)
                nothing is read and ``nodes`` / ``faces`` stay ``None``.
        """
        self.nodes = None
        self.faces = None
        if not obj_file:
            return
        self.ObjParse(obj_file)

    def ObjInfo(self):
        print("Num vertices  :    %d" % (len(self.nodes)))
        print("Num faces     :    %d" % (len(self.faces)))
        nmin, nmax = self.MinMaxNodes()
        print("Min/Max       :    %s %s" % (np.around(nmin, 3), np.around(nmax, 3)))

    @staticmethod
    def MinMax3d(arr):
        nmin = 1e9 * np.ones((3))
        nmax = -1e9 * np.ones((3))
        for a in arr:
            for i in range(3):
                nmin[i] = min(nmin[i], a[i])
                nmax[i] = max(nmax[i], a[i])
        return (nmin, nmax)

    def MinMaxNodes(self):
        return ObjFile.MinMax3d(self.nodes)

    def ObjParse(self, obj_file):
        f = open(obj_file)
        lines = f.readlines()
        f.close()
        nodes = []
        # add zero entry to get ids right
        nodes.append([0.0, 0.0, 0.0])
        faces = []
        for line in lines:
            if "v" == line[0] and line[1].isspace():  # do not match "vt" or "vn"
                v = line.split()
                nodes.append(ObjFile.ToFloats(v[1:])[:3])
            if "f" == line[0]:
                # remove /int
                line = re.sub(RE, "", line)
                f = line.split()
                faces.append(ObjFile.ToInts([s.split("/")[0] for s in f[1:]]))

        self.nodes = np.array(nodes)
        assert np.shape(self.nodes)[1] == 3
        self.faces = faces

    def ObjWrite(self, obj_file):
        f = open(obj_file, "w")
        for n in self.nodes[1:]:  # skip first dummy 'node'
            f.write("v ")
            for nn in n:
                f.write("%g " % (nn))
            f.write("\n")
        for ff in self.faces:
            f.write("f ")
            for fff in ff:
                f.write("%d " % (fff))
            f.write("\n")

    @staticmethod
    def ToFloats(n):
        if isinstance(n, list):
            v = []
            for nn in n:
                v.append(float(nn))
            return v
        else:
            return float(n)

    @staticmethod
    def ToInts(n):
        if isinstance(n, list):
            v = []
            for nn in n:
                v.append(int(nn))
            return v
        else:
            return int(n)

    @staticmethod
    def Normalize(v):
        v2 = np.linalg.norm(v)
        if v2 < 0.000000001:
            return v
        else:
            return v / v2

    def QuadToTria(self):
        """Return the faces of the mesh split into triangles.

        Every face with three or more vertices is fan-triangulated around its
        first vertex: ``[v0, v1, v2, v3]`` becomes ``[v0, v1, v2]`` and
        ``[v0, v2, v3]``, and so on for larger polygons. A fan is only
        geometrically exact for convex faces. Faces with fewer than three
        vertices produce no triangles.

        Returns:
            List of triangles, each a new list of three vertex indices.
            ``self.faces`` is not modified.
        """
        trifaces = []
        for face in self.faces:
            # range() is empty for faces with fewer than three vertices.
            for k in range(1, len(face) - 1):
                trifaces.append([face[0], face[k], face[k + 1]])
        return trifaces

    @staticmethod
    def ScaleVal(v, scale, minval=True):

        if minval:
            if v > 0:
                return v * (1.0 - scale)
            else:
                return v * scale
        else:  # maxval
            if v > 0:
                return v * scale
            else:
                return v * (1.0 - scale)
            
    def Plot(
        self,
        output_file=None,
        elevation=None,
        azim=None,
        width=None,
        height=None,
        xlim=None,
        ylim=None,
        zlim=None,
        pose=None,
    ):
        """Render the mesh as a white surface on a black background.

        Args:
            output_file: Path the image is saved to. When falsy the figure is
                built but neither saved nor closed.
            elevation: View elevation in degrees; 30 when omitted. Ignored if
                ``pose`` is given.
            azim: View azimuth in degrees; 30 when omitted. Ignored if
                ``pose`` is given.
            width: Output width in pixels. Only used together with ``height``.
            height: Output height in pixels. Only used together with ``width``.
            xlim: Axis limits for x. Only applied when ``xlim``, ``ylim`` and
                ``zlim`` are all given; otherwise the limits come from
                ``MinMaxNodes()``.
            ylim: Axis limits for y (see ``xlim``).
            zlim: Axis limits for z (see ``xlim``).
            pose: Optional array whose ``[:3, 2]`` column is used as the
                viewing direction, from which elevation and azimuth are
                derived (approximate).
        """
        plt.ioff()
        tri = self.QuadToTria()
        fig = plt.figure(facecolor="black")
        ax = fig.add_subplot(111, projection="3d")
        ax.set_facecolor("black")
        # self.nodes keeps its placeholder row 0, so the 1-based face indices
        # in `tri` address the coordinate arrays directly.
        ax.plot_trisurf(
            self.nodes[:, 0], self.nodes[:, 1], self.nodes[:, 2], triangles=tri, color="white"
        )
        ax.axis("off")
        fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
        # NOTE(review): the box aspect is taken from the auto-scaled limits,
        # before the explicit limits below are applied - confirm this order
        # is intended.
        limits = np.array([getattr(ax, f"get_{axis}lim")() for axis in "xyz"])
        ax.set_box_aspect(np.ptp(limits, axis=1))

        # Use provided axis limits if given, otherwise compute from nodes
        if xlim is not None and ylim is not None and zlim is not None:
            ax.set_xlim(xlim)
            ax.set_ylim(ylim)
            ax.set_zlim(zlim)
        else:
            nmin, nmax = self.MinMaxNodes()
            ax.set_xlim(nmin[0], nmax[0])
            ax.set_ylim(nmin[1], nmax[1])
            ax.set_zlim(nmin[2], nmax[2])

        if pose is not None:
            # Extract elevation and azimuth from pose (approximate)
            z_axis = np.asarray(pose, dtype=float)[:3, 2]
            # Clip so rounding noise slightly outside [-1, 1] cannot turn the
            # elevation into NaN.
            view_elev = np.degrees(np.arcsin(np.clip(z_axis[2], -1.0, 1.0)))
            view_azim = np.degrees(np.arctan2(z_axis[1], z_axis[0]))
        else:
            view_elev = 30 if elevation is None else elevation
            view_azim = 30 if azim is None else azim
        ax.view_init(view_elev, view_azim)

        if output_file:
            dpi = None
            if width and height:
                # Make the longer side 1 inch and use its pixel count as dpi,
                # so the saved image measures width x height pixels.
                width_inches = 1.0 if width >= height else (width / height)
                height_inches = 1.0 if height >= width else (height / width)
                dpi = max(width, height)
                fig.set_size_inches(width_inches, height_inches)
            # Save and close this figure explicitly instead of relying on
            # pyplot's notion of the "current" figure.
            fig.savefig(output_file, dpi=dpi, facecolor="black", transparent=False)
            plt.close(fig)

    def PlotDepth(
        self,
        output_file=None,
        elevation=None,
        azim=None,
        width=400,
        height=400,
        xlim=None,
        ylim=None,
        zlim=None,
        pose=None,
        K=None, 
        mask=None,
    ):
        plt.ioff()
        if K is None:
            fov = 60
            fx = fy = width / (2 * np.tan(np.radians(fov / 2)))
            K = np.array([[fx, 0, width/2],
                          [0, fy, height/2],
                          [0,  0,   1]], dtype=np.float32)

        if pose is None:
            if elevation is None or azim is None:
                raise ValueError("Either pose or both elevation and azim must be provided.")
            elev_rad = np.radians(elevation)
            azim_rad = np.radians(azim)
            radius = 1.5
            x = radius * np.cos(elev_rad) * np.cos(azim_rad)
            y = radius * np.cos(elev_rad) * np.sin(azim_rad)
            z = radius * np.sin(elev_rad)
            pos = np.array([x, y, z])
            z_axis = -pos / np.linalg.norm(pos)
            x_axis = np.cross([0, 1, 0], z_axis)
            x_axis /= np.linalg.norm(x_axis)
            y_axis = np.cross(z_axis, x_axis)
            R = np.stack([x_axis, y_axis, z_axis], axis=1)
            t = pos
            pose = np.eye(4)
            pose[:3, :3] = R
            pose[:3, 3] = t
            pose = pose[:3, :4]

        R = pose[:3, :3]
        t = pose[:3, 3]
        vertices_world = self.nodes[1:]
        vertices_camera = (R.T @ (vertices_world - t).T).T

        z = vertices_camera[:, 2]
        if np.all(z <= 0):
            print(f"Warning: All depth values are non-positive for {output_file}")
        xy = vertices_camera[:, :2] / z[:, None]
        uv = (K[:2, :2] @ xy.T + K[:2, 2:3]).T

        mask_bounds = (uv[:, 0] >= 0) & (uv[:, 0] < width) & (uv[:, 1] >= 0) & (uv[:, 1] < height) & (z > 0)
        points = uv[mask_bounds]
        values = z[mask_bounds]

        u, v = np.meshgrid(np.arange(width), np.arange(height))
        from scipy.interpolate import griddata
        
        if mask is not None:
            grid_points = np.stack([v.flatten(), u.flatten()], axis=1)
            depth_map = griddata(points, values, grid_points, method='linear', fill_value=np.inf)
            depth_map = depth_map.reshape(height, width)
            depth_map[~mask] = np.inf
        else:
            depth_map = griddata(points, values, (u, v), method='linear', fill_value=np.inf)

        if output_file:
            if output_file.endswith('.npy'):
                np.save(output_file, depth_map)
            else:
                finite_depth = depth_map[~np.isinf(depth_map)]
                if len(finite_depth) > 0:
                    dmin, dmax = finite_depth.min(), finite_depth.max()
                    depth_normalized = np.zeros_like(depth_map, dtype=np.float32)
                    mask_finite = ~np.isinf(depth_map)
                    depth_normalized[mask_finite] = 1.0 - (depth_map[mask_finite] - dmin) / (dmax - dmin)
                    depth_img = (255 * depth_normalized).astype(np.uint8)
                else:
                    depth_img = np.zeros((height, width), dtype=np.uint8)  # All black if no object

                fig = plt.figure(facecolor='black')
                ax = fig.add_subplot(111)
                ax.imshow(depth_img, cmap='gray', vmin=0, vmax=255)
                ax.axis('off')
                fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
                plt.savefig(output_file, dpi=width/8, facecolor='black', transparent=False)
                plt.close()


if __name__ == "__main__":
    # Script usage: load one sample mesh and render it with the default view.
    obj_file = "../../partnet_data/555/objs/new-0.obj"
    out_file = "../../partnet_data/555/pngs/new-0.png"
    obj = ObjFile(obj_file)
    obj.Plot(out_file)